/**
 * Copyright 2023-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cast_check_utils.h"

#include <algorithm>
#include <list>
#include <vector>

#include "infra/base/assertion.h"
#include "framework/infra/log/log.h"

#include "graph/op/math_defs.h"
#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/core/node/node_functor.h"
#include "framework/graph/core/cgraph/graph_list_walker.h"

using namespace ge;

namespace hiai {
namespace {
struct DataTypePair {
    DataType_t xType;
    DataType_t yType;
};

struct FormatPair {
    ge::Format inputFormat;
    ge::Format outputFormat;
};

static const vector<DataTypePair> supportDataTypeList {
    {CC_DATA_FLOAT, CC_DATA_HALF}, {CC_DATA_HALF, CC_DATA_FLOAT},
    {CC_DATA_UINT8, CC_DATA_FLOAT}, {CC_DATA_QUINT8, CC_DATA_FLOAT},
    {CC_DATA_INT8, CC_DATA_FLOAT}, {CC_DATA_INT64, CC_DATA_INT32},
    {CC_DATA_INT64, CC_DATA_FLOAT}, {CC_DATA_HALF, CC_DATA_UINT8},
    {CC_DATA_INT32, CC_DATA_FLOAT}
};

static const vector<FormatPair> supportFormatNeedCheckList {
    {FORMAT_NC1HWC0, FORMAT_NCHW},
    {FORMAT_NC1HWC0, FORMAT_NHWC},
    {FORMAT_NHWC, FORMAT_NC1HWC0},
    {FORMAT_NCHW, FORMAT_NC1HWC0},
    {FORMAT_NHWC, FORMAT_NCHW}
};

static const vector<FormatPair> supportFormatList {
    {FORMAT_ND, FORMAT_NCHW},
    {FORMAT_NCHW, FORMAT_ND},
    {FORMAT_NHWC, FORMAT_ND},
    {FORMAT_ND, FORMAT_NC1HWC0},
    {FORMAT_NC1HWC0, FORMAT_ND}
};
} // namespace

/**
 * Decides whether a Cast from element type xType to yType is allowed.
 * A no-op cast (same type on both sides) is always allowed; any other pair must
 * appear in supportDataTypeList.
 * @param xType data type of the Cast input
 * @param yType data type of the Cast output
 * @return true when the conversion is supported; false (with an error log) otherwise
 */
bool CastCheckUtils::CheckDataType(DataType_t xType, DataType_t yType)
{
    if (xType == yType) {
        return true;
    }

    const auto matchesRequest = [xType, yType](const DataTypePair& candidate) {
        return candidate.xType == xType && candidate.yType == yType;
    };
    const bool supported =
        std::find_if(supportDataTypeList.cbegin(), supportDataTypeList.cend(), matchesRequest) !=
        supportDataTypeList.cend();
    if (!supported) {
        FMK_LOGE("TransDataType from %d to %d is not supported.", xType, yType);
    }
    return supported;
}

/**
 * Decides whether a Cast may change the tensor format from inputDesc to outputDesc.
 * - Identical formats are always accepted.
 * - Pairs in supportFormatNeedCheckList are accepted only if CheckDims agrees the
 *   dims of the two descriptors correspond.
 * - Pairs in supportFormatList are accepted unconditionally.
 * - Anything else is rejected with an error log.
 * @param inputDesc  descriptor of the Cast input
 * @param outputDesc descriptor of the Cast output
 * @return true when the format change is supported
 */
bool CastCheckUtils::CheckFormat(const ge::TensorDesc& inputDesc, const ge::TensorDesc& outputDesc)
{
    const auto inputFormat = inputDesc.GetFormat();
    const auto outputFormat = outputDesc.GetFormat();
    if (inputFormat == outputFormat) {
        return true;
    }

    // A plain lambda is enough here; wrapping it in std::function only added type erasure
    // (an indirect call) and a dependency on a transitively included <functional>.
    const auto matchesPair = [inputFormat, outputFormat](const FormatPair& supported) {
        return supported.inputFormat == inputFormat && supported.outputFormat == outputFormat;
    };

    if (std::any_of(supportFormatNeedCheckList.cbegin(), supportFormatNeedCheckList.cend(), matchesPair)) {
        return CheckDims(inputDesc, outputDesc);
    }
    if (std::any_of(supportFormatList.cbegin(), supportFormatList.cend(), matchesPair)) {
        return true;
    }
    FMK_LOGE("TransTensor:from format:[%d] to [%d] is not be supported!", inputFormat, outputFormat);
    return false;
}

/**
 * Pads a shape of rank 0..3 up to four dims; shapes of rank 4 or 5 are copied unchanged.
 *   rank 0: []        -> [1, 1, 1, 1]
 *   rank 1: [a]       -> [1, a, 1, 1]
 *   rank 2: [a, b]    -> [1, a, b, 1]
 *   rank 3: [a, b, c] -> [1, a, b, c]
 * @param dims    source dimensions
 * @param newDims receives the padded dimensions (may be the same object as dims)
 * @return hiai::SUCCESS; hiai::FAILED (newDims left untouched) when dims has more than 5 entries
 */
hiai::Status CastCheckUtils::TransferDim(const std::vector<int64_t>& dims, std::vector<int64_t>& newDims)
{
    constexpr size_t kPaddedRank = 4;
    constexpr size_t kMaxRank = 5;

    // Validate before doing any work, and keep the full size_t (no narrowing to uint32_t).
    const size_t rank = dims.size();
    if (rank > kMaxRank) {
        FMK_LOGE("Cannot support inputShapeSize %zu", rank);
        return hiai::FAILED;
    }
    if (rank >= kPaddedRank) {
        newDims = dims; // self-assignment is safe if the caller aliases the two vectors
        return hiai::SUCCESS;
    }

    // Leading 1, then the original dims, then trailing 1s up to four entries. Built in a
    // separate vector so that aliasing dims and newDims stays safe.
    std::vector<int64_t> padded;
    padded.reserve(kPaddedRank);
    padded.push_back(1);
    padded.insert(padded.end(), dims.cbegin(), dims.cend());
    padded.resize(kPaddedRank, 1);
    newDims.swap(padded);
    return hiai::SUCCESS;
}

/**
 * Checks that the dims of a layout-changing Cast correspond between input and output.
 * The index mappings compare dims as if NCHW / NC1HWC0 shapes were laid out [N, C, H, W]
 * and NHWC shapes [N, H, W, C].
 * - NC1HWC0 <-> NCHW: the dim vectors must be identical.
 * - NC1HWC0 -> NHWC and NHWC -> NCHW / NC1HWC0: the first four dims must match under
 *   the channel permutation.
 * Any other format pair is accepted without a check.
 * @return true when the dims are consistent; false otherwise (including when a permutation
 *         check is needed but either vector has fewer than four dims)
 */
bool CastCheckUtils::CheckDimValue(ge::Format inputFormat, ge::Format outputFormat,
    const std::vector<int64_t>& inputDims, const std::vector<int64_t>& outputDims)
{
    if ((inputFormat == FORMAT_NC1HWC0 && outputFormat == FORMAT_NCHW) ||
        (inputFormat == FORMAT_NCHW && outputFormat == FORMAT_NC1HWC0)) {
        return inputDims == outputDims;
    }

    // The permutation checks below read indices 0..3 of both vectors. CheckDims pads shapes first,
    // but this function does not rely on that: shorter shapes are rejected instead of being read
    // out of bounds.
    const bool hasFourDims = inputDims.size() >= 4 && outputDims.size() >= 4;

    if (inputFormat == FORMAT_NC1HWC0 && outputFormat == FORMAT_NHWC) {
        HIAI_EXPECT_TRUE_R(hasFourDims, false);
        HIAI_EXPECT_TRUE_R((inputDims[0] == outputDims[0]), false);
        HIAI_EXPECT_TRUE_R((inputDims[1] == outputDims[3]), false);
        HIAI_EXPECT_TRUE_R((inputDims[2] == outputDims[1]), false);
        HIAI_EXPECT_TRUE_R((inputDims[3] == outputDims[2]), false);
        return true;
    }
    if ((inputFormat == FORMAT_NHWC && outputFormat == FORMAT_NCHW) ||
        (inputFormat == FORMAT_NHWC && outputFormat == FORMAT_NC1HWC0)) {
        HIAI_EXPECT_TRUE_R(hasFourDims, false);
        HIAI_EXPECT_TRUE_R((outputDims[0] == inputDims[0]), false);
        HIAI_EXPECT_TRUE_R((outputDims[1] == inputDims[3]), false);
        HIAI_EXPECT_TRUE_R((outputDims[2] == inputDims[1]), false);
        HIAI_EXPECT_TRUE_R((outputDims[3] == inputDims[2]), false);
        return true;
    }
    return true;
}

/**
 * Pads both shapes with TransferDim, then compares them with CheckDimValue according to
 * the input/output formats.
 * @return false if either shape cannot be padded or the padded dims do not correspond
 */
bool CastCheckUtils::CheckDims(const ge::TensorDesc& inputDesc, const ge::TensorDesc& outputDesc)
{
    const auto inputFormat = inputDesc.GetFormat();
    const auto outputFormat = outputDesc.GetFormat();

    // The original shapes are copied; the padded versions are written into the new* vectors.
    const std::vector<int64_t> originInputDims = inputDesc.GetShape().GetDims();
    const std::vector<int64_t> originOutputDims = outputDesc.GetShape().GetDims();
    std::vector<int64_t> newInputDims;
    std::vector<int64_t> newOutputDims;
    HIAI_EXPECT_TRUE_R(TransferDim(originInputDims, newInputDims) == hiai::SUCCESS, false);
    HIAI_EXPECT_TRUE_R(TransferDim(originOutputDims, newOutputDims) == hiai::SUCCESS, false);

    HIAI_EXPECT_TRUE_R(CheckDimValue(inputFormat, outputFormat, newInputDims, newOutputDims), false);
    return true;
}

/**
 * Validates one input/output descriptor pair of a Cast node.
 * - Both data types must lie in [CC_DATA_FLOAT, CC_DATA_RESERVED - 1], and the conversion
 *   must be accepted by CheckDataType.
 * - Both formats must lie in [FORMAT_NCHW, FORMAT_RESERVED - 1], and the format change
 *   must be accepted by CheckFormat.
 * Each HIAI_EXPECT_* check presumably returns false from this function when it fails
 * (the macros are defined elsewhere — confirm).
 * @return true when every check passes
 */
bool CastCheckUtils::CheckInputOutputTensor(const ge::TensorDesc& inputDesc, const ge::TensorDesc& outputDesc)
{
    // Data types are range-checked as plain integers before being converted back to enum values.
    // NOTE(review): "outputDataTpe" is a typo; renaming it may change log text if the macros
    // stringify their arguments — check before touching.
    int32_t inputDataType = static_cast<int32_t>(inputDesc.GetDataType());
    int32_t outputDataTpe = static_cast<int32_t>(outputDesc.GetDataType());
    HIAI_EXPECT_IN_RANGE_R(inputDataType, CC_DATA_FLOAT, CC_DATA_RESERVED - 1, false);
    HIAI_EXPECT_IN_RANGE_R(outputDataTpe, CC_DATA_FLOAT, CC_DATA_RESERVED - 1, false);
    HIAI_EXPECT_TRUE_R(CheckDataType(tagDataType(inputDataType), tagDataType(outputDataTpe)), false);

    // Formats: enum range check first, then the supported-conversion (and possibly dim) check.
    HIAI_EXPECT_IN_RANGE_R(inputDesc.GetFormat(), FORMAT_NCHW, FORMAT_RESERVED - 1, false);
    HIAI_EXPECT_IN_RANGE_R(outputDesc.GetFormat(), FORMAT_NCHW, FORMAT_RESERVED - 1, false);
    HIAI_EXPECT_TRUE_R(CheckFormat(inputDesc, outputDesc), false);
    return true;
}

/**
 * Runs CheckInputOutputTensor on every (input i, output i) descriptor pair of node.
 * NOTE(review): output descriptor i is read for every input index i, which assumes the node has
 * at least as many output descriptors as input descriptors (true for a single-in/single-out Cast) — confirm.
 * @return hiai::SUCCESS when all pairs pass, hiai::FAILED on the first failing pair
 */
hiai::Status CastCheckUtils::CheckInputOutput(const ge::Node& node)
{
    // Resolve the op descriptor once; auto&& binds to whatever reference OpDesc() yields.
    auto&& opDesc = node.ROLE(NodeSpec).OpDesc();
    const int32_t inputSize = opDesc.GetInputsDescSize();
    for (int32_t i = 0; i < inputSize; ++i) {
        auto& inputDesc = opDesc.GetInputDesc(i);
        auto& outputDesc = opDesc.GetOutputDesc(i);
        HIAI_EXPECT_TRUE_R(CheckInputOutputTensor(inputDesc, outputDesc), hiai::FAILED);
    }
    return hiai::SUCCESS;
}

/**
 * Walks all nodes of computeGraph and applies CheckInputOutput to those selected by
 * NodeFunctor::Typed for hiai::op::CastT::TYPE.
 * @return the status reported by GraphListWalker::WalkAllNodes (presumably the first failing
 *         node's status, or success — confirm against the walker's contract)
 */
hiai::Status CastCheckUtils::CastBuildVerify(std::shared_ptr<ComputeGraph>& computeGraph)
{
    auto verifyCastNode = [](Node& node) -> hiai::Status {
        return CheckInputOutput(node);
    };
    return computeGraph->ROLE(GraphListWalker).WalkAllNodes(
        ge::NodeFunctor::Typed({hiai::op::CastT::TYPE}, verifyCastNode));
}
} // namespace hiai
